{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import open3d as o3d\n",
    "import torch\n",
    "import numpy as np\n",
    "import os\n",
    "import os.path as osp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "  <img width=\"40%\" src=\"https://raw.githubusercontent.com/nicolas-chaulet/torch-points3d/master/docs/logo.png\" />\n",
    "</p>\n",
    "\n",
    "# Registration Demo on KITTI Odometry\n",
    "\n",
    "In this task, we will show a demonstration of registration on KITTI odometry using a pretrained network. from scratch\n",
    "\n",
    "First let's load some examples. We multiply by the calibration matrix to be correctly oriented"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We read the data\n",
    "path_s = \"data/KITTI/000000.bin\"\n",
    "path_t = \"data/KITTI/000049.bin\"\n",
    "R_calib = np.asarray([[-1.857739385241e-03, -9.999659513510e-01, -8.039975204516e-03, -4.784029760483e-03],\n",
    "                      [-6.481465826011e-03, 8.051860151134e-03, -9.999466081774e-01, -7.337429464231e-02],\n",
    "                      [9.999773098287e-01, -1.805528627661e-03, -6.496203536139e-03, -3.339968064433e-01]])\n",
    "pcd_s = np.fromfile(path_s, dtype=np.float32).reshape(-1, 4)[:, :3].dot(R_calib[:3, :3].T)\n",
    "pcd_t = np.fromfile(path_t, dtype=np.float32).reshape(-1, 4)[:, :3].dot(R_calib[:3, :3].T)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we can put the point cloud in the class Batch, apply some transformation (transform data into sparse voxels, add ones). We can load the model too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data preprocessing import\n",
    "from torch_points3d.core.data_transform import GridSampling3D, AddOnes, AddFeatByKey\n",
    "from torch_geometric.transforms import Compose\n",
    "from torch_geometric.data import Batch\n",
    "\n",
    "# Model\n",
    "from torch_points3d.applications.pretrained_api import PretainedRegistry\n",
    "\n",
    "# post processing\n",
    "from torch_points3d.utils.registration import get_matches, fast_global_registration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = Compose([GridSampling3D(mode='last', size=0.3, quantize_coords=True), AddOnes(), AddFeatByKey(add_to_x=True, feat_name=\"ones\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_s = transform(Batch(pos=torch.from_numpy(pcd_s).float(), batch=torch.zeros(pcd_s.shape[0]).long()))\n",
    "data_t = transform(Batch(pos=torch.from_numpy(pcd_t).float(), batch=torch.zeros(pcd_t.shape[0]).long()))\n",
    "\n",
    "\n",
    "\n",
    "model = PretainedRegistry.from_pretrained(\"minkowski-registration-kitti\").cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "o3d_pcd_s = o3d.geometry.PointCloud()\n",
    "o3d_pcd_s.points = o3d.utility.Vector3dVector(data_s.pos.cpu().numpy())\n",
    "o3d_pcd_s.paint_uniform_color([0.9, 0.7, 0.1])\n",
    "\n",
    "o3d_pcd_t = o3d.geometry.PointCloud()\n",
    "o3d_pcd_t.points = o3d.utility.Vector3dVector(data_t.pos.cpu().numpy())\n",
    "o3d_pcd_t.paint_uniform_color([0.1, 0.7, 0.9])\n",
    "# visualizer = o3d.JVisualizer()\n",
    "# visualizer.add_geometry(o3d_pcd_s)\n",
    "# visualizer.add_geometry(o3d_pcd_t)\n",
    "# visualizer.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    model.set_input(data_s, \"cuda\")\n",
    "    output_s = model.forward()\n",
    "    model.set_input(data_t, \"cuda\")\n",
    "    output_t = model.forward()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " Now we have our feature let's match our features. We will select 5000 points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rand_s = torch.randint(0, len(output_s), (5000, ))\n",
    "rand_t = torch.randint(0, len(output_t), (5000, ))\n",
    "\n",
    "matches = get_matches(output_s[rand_s], output_t[rand_t])\n",
    "\n",
    "T_est = fast_global_registration(data_s.pos[rand_s][matches[:, 0]], data_t.pos[rand_t][matches[:, 1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "visualizer = o3d.JVisualizer()\n",
    "visualizer.add_geometry(o3d_pcd_s.transform(T_est.cpu().numpy()))\n",
    "visualizer.add_geometry(o3d_pcd_t)\n",
    "visualizer.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
